package gb

import (
	"github.com/akatsuki105/dawngb/util"
	. "github.com/akatsuki105/dawngb/util/datasize"
)

// Memory is the CPU-visible bus. It forwards accesses to the cartridge,
// video, audio, timer, serial, input and DMA units reached through gb, and
// itself owns work RAM, high RAM and storage for registers 0xFF72-0xFF74.
type Memory struct {
	gb *GB // owning console; every peripheral is reached through it

	// wram holds eight 4 KB banks. Bank 0 is always visible at
	// 0xC000-0xCFFF (mirrored at 0xE000-0xEFFF); the bank numbered by
	// wramBank is visible at 0xD000-0xDFFF (mirrored at 0xF000-0xFDFF).
	wram     [(4 * KB) * 8]uint8
	wramBank uint

	// hram backs 0xFF80-0xFFFE, indexed by addr&0x7F.
	hram [0x7F]uint8

	// Registers 0xFF72-0xFF74 are stored and returned verbatim.
	ff72 uint8
	ff73 uint8
	ff74 uint8
}

// newMemory creates the bus for gb. RAM starts zeroed and WRAM bank 1 is
// selected, which is the same state Reset restores.
func newMemory(gb *GB) *Memory {
	m := new(Memory)
	m.gb = gb
	m.wramBank = 1
	return m
}

// Reset clears work RAM and high RAM, reselects WRAM bank 1 and zeroes the
// 0xFF72-0xFF74 registers. hasBIOS is accepted but not consulted here.
func (m *Memory) Reset(hasBIOS bool) {
	// Assigning a zero-valued array literal clears every element at once.
	m.wram = [len(m.wram)]uint8{}
	m.hram = [len(m.hram)]uint8{}

	m.wramBank = 1

	m.ff72 = 0
	m.ff73 = 0
	m.ff74 = 0
}

// Read returns the byte the CPU sees at addr.
//
// Memory regions are resolved first, by 4 KB page (addr>>12) and simple
// address ranges. Anything left over in page 0xF is treated as an I/O
// register and matched by exact address. Addresses with no handler read as 0.
func (m *Memory) Read(addr uint16) byte {
	// NOTE(review): this condition can never be true, because no uint16 is
	// both below 0xFF80 and above 0xFFFE. As written, OAM DMA never blocks
	// CPU reads. The apparent intent is "only HRAM is readable during DMA"
	// (addr < 0xFF80 || addr > 0xFFFE). Confirm the DMA transfer does not
	// fetch its source bytes through Memory.Read before switching to ||,
	// otherwise the transfer itself would start reading 0xFF.
	if m.gb.oamDMA.active && addr < 0xFF80 && addr > 0xFFFE {
		return 0xFF
	}

	switch page := addr >> 12; {
	case page <= 0x7, page == 0xA, page == 0xB:
		return m.gb.cartridge.Read(addr)
	case page == 0x8, page == 0x9:
		return m.gb.video.Read(addr)
	case page == 0xC, page == 0xE:
		// WRAM bank 0; page 0xE aliases page 0xC.
		return m.wram[addr&0xFFF]
	case page == 0xD, addr <= 0xFDFF:
		// Switchable WRAM bank, at 0xD000-0xDFFF and again at 0xF000-0xFDFF.
		return m.wram[(m.wramBank<<12)|uint(addr&0xFFF)]
	case addr <= 0xFE9F:
		// Only 0xFE00-0xFE9F reaches this case.
		return m.gb.video.Read(addr)
	case addr >= 0xFF30 && addr <= 0xFF3F:
		return m.gb.audio.Read(addr)
	case addr >= 0xFF80 && addr <= 0xFFFE:
		return m.hram[addr&0x7F]
	default:
		switch addr {
		case 0xFF00:
			return m.gb.input.Read(addr)
		case 0xFF01, 0xFF02:
			return m.gb.serial.Read(addr)
		case 0xFF04, 0xFF05, 0xFF06, 0xFF07:
			return m.gb.timer.Read(addr)
		case 0xFF0F:
			// Pack the five interrupt-request flags into bits 0-4.
			var flags uint8
			for bit := 0; bit < 5; bit++ {
				flags |= uint8(util.Btoi(m.gb.interrupt[bit])) << bit
			}
			return flags
		case 0xFF10, 0xFF11, 0xFF12, 0xFF13, 0xFF14,
			0xFF16, 0xFF17, 0xFF18, 0xFF19, 0xFF1A, 0xFF1B, 0xFF1C, 0xFF1D, 0xFF1E,
			0xFF20, 0xFF21, 0xFF22, 0xFF23, 0xFF24, 0xFF25, 0xFF26:
			return m.gb.audio.Read(addr)
		case 0xFF40, 0xFF41, 0xFF42, 0xFF43, 0xFF44, 0xFF45,
			0xFF47, 0xFF48, 0xFF49, 0xFF4A, 0xFF4B, 0xFF4F,
			0xFF68, 0xFF69, 0xFF6A, 0xFF6B:
			return m.gb.video.Read(addr)
		case 0xFF4D:
			// Bits 1-6 read as 1; bit 7 reflects cpu.Cycle == 4, bit 0 is key1.
			doubleSpeed := m.gb.cpu.Cycle == 4
			key1 := util.SetBit(uint8(0x7E), 7, doubleSpeed)
			return util.SetBit(key1, 0, m.gb.key1)
		case 0xFF50:
			return 1
		case 0xFF51, 0xFF52, 0xFF53, 0xFF54, 0xFF55:
			if m.gb.cartridge.IsCGB() {
				return m.gb.dmac.Read(addr)
			}
		case 0xFF56:
			return 0x02 // TODO: infrared
		case 0xFF70:
			return uint8(m.wramBank)
		case 0xFF72:
			return m.ff72
		case 0xFF73:
			return m.ff73
		case 0xFF74:
			return m.ff74
		case 0xFFFF:
			return m.gb.ie
		}
	}
	return 0
}

// Write routes a CPU write of val to addr.
//
// Dispatch mirrors Read. Memory regions are resolved first, by 4 KB page and
// simple address ranges. The rest of page 0xF is matched by exact I/O
// register address. Writes to addresses with no handler, and CGB-only
// registers on a non-CGB cartridge, are dropped.
func (m *Memory) Write(addr uint16, val byte) {
	// NOTE(review): this condition can never be true, because no uint16 is
	// both below 0xFF80 and above 0xFFFE. As written, OAM DMA never blocks
	// CPU writes. The apparent intent is "only HRAM is writable during DMA"
	// (addr < 0xFF80 || addr > 0xFFFE). Confirm how the DMA transfer accesses
	// the bus before switching to ||.
	if m.gb.oamDMA.active && addr < 0xFF80 && addr > 0xFFFE {
		return
	}

	switch page := addr >> 12; {
	case page <= 0x7, page == 0xA, page == 0xB:
		m.gb.cartridge.Write(addr, val)
	case page == 0x8, page == 0x9:
		m.gb.video.Write(addr, val)
	case page == 0xC, page == 0xE:
		// WRAM bank 0; page 0xE aliases page 0xC.
		m.wram[addr&0xFFF] = val
	case page == 0xD, addr <= 0xFDFF:
		// Switchable WRAM bank, at 0xD000-0xDFFF and again at 0xF000-0xFDFF.
		m.wram[(m.wramBank<<12)|uint(addr&0xFFF)] = val
	case addr <= 0xFE9F:
		// Only 0xFE00-0xFE9F reaches this case.
		m.gb.video.Write(addr, val)
	case addr >= 0xFF30 && addr <= 0xFF3F:
		m.gb.audio.Write(addr, val)
	case addr >= 0xFF80 && addr <= 0xFFFE:
		m.hram[addr&0x7F] = val
	default:
		switch addr {
		case 0xFF00:
			m.gb.input.Write(addr, val)
		case 0xFF01, 0xFF02:
			m.gb.serial.Write(addr, val)
		case 0xFF04, 0xFF05, 0xFF06, 0xFF07:
			m.gb.timer.Write(addr, val)
		case 0xFF0F:
			// Unpack bits 0-4 into the five interrupt-request flags.
			for bit := 0; bit < 5; bit++ {
				m.gb.interrupt[bit] = util.Bit(val, bit)
			}
		case 0xFF10, 0xFF11, 0xFF12, 0xFF13, 0xFF14,
			0xFF16, 0xFF17, 0xFF18, 0xFF19, 0xFF1A, 0xFF1B, 0xFF1C, 0xFF1D, 0xFF1E,
			0xFF20, 0xFF21, 0xFF22, 0xFF23, 0xFF24, 0xFF25, 0xFF26:
			m.gb.audio.Write(addr, val)
		case 0xFF40, 0xFF41, 0xFF42, 0xFF43, 0xFF44, 0xFF45,
			0xFF47, 0xFF48, 0xFF49, 0xFF4A, 0xFF4B,
			0xFF68, 0xFF69, 0xFF6A, 0xFF6B, 0xFF6C:
			m.gb.video.Write(addr, val)
		case 0xFF46:
			// val selects the source page; the transfer starts at val<<8.
			m.gb.triggerOAMDMA(uint16(val) << 8)
		case 0xFF4D:
			m.gb.key1 = util.Bit(val, 0)
		case 0xFF4F:
			if m.gb.cartridge.IsCGB() {
				m.gb.video.Write(addr, val)
			}
		case 0xFF51, 0xFF52, 0xFF53, 0xFF54, 0xFF55:
			if m.gb.cartridge.IsCGB() {
				m.gb.dmac.Write(addr, val)
			}
		case 0xFF70:
			if m.gb.cartridge.IsCGB() {
				bank := uint(val & 0b111)
				if bank == 0 {
					bank = 1 // selecting 0 maps bank 1 instead
				}
				m.wramBank = bank
			}
		case 0xFF72:
			m.ff72 = val
		case 0xFF73:
			m.ff73 = val
		case 0xFF74:
			m.ff74 = val
		case 0xFFFF:
			m.gb.ie = val
		}
	}
}
